#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn

class GIOULoss(nn.Module):
    """Generalized IoU (GIoU) loss between row-wise paired boxes.

    Computes ``loss = 1 - GIoU`` with
    ``GIoU = IoU - (|C| - |A u B|) / |C|``, where ``C`` is the smallest
    axis-aligned box enclosing both ``A`` and ``B`` (Rezatofighi et al.,
    "Generalized Intersection over Union", CVPR 2019).

    Args:
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.
        to_remove: Constant added to every width/height before an area is
            computed. With the default ``1``, a box spanning x=1..3 counts as
            3 units wide (inclusive-pixel convention). Pass ``0`` for
            continuous coordinates as in the GIoU paper.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, reduction='mean', to_remove=1):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.reduction = reduction
        self.to_remove = to_remove

    def _area(self, lt, rb):
        """Return per-row areas of boxes given (N, 2) top-left / bottom-right corners.

        Negative extents (empty or inverted boxes) are clamped to 0.
        """
        wh = (rb - lt + self.to_remove).clamp(min=0)
        return wh[:, 0] * wh[:, 1]

    def forward(self, gt_bboxes, pr_bboxes):
        """Compute the GIoU loss between ``gt_bboxes[i]`` and ``pr_bboxes[i]``.

        Args:
            gt_bboxes: Tensor of shape (N, 4), boxes as (x1, y1, x2, y2).
            pr_bboxes: Tensor of shape (N, 4), same layout. Boxes are compared
                row by row; this is not a pairwise N x M matrix.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or a tensor of shape
            (N,) with the per-pair loss for ``'none'``.

        Raises:
            ValueError: if ``self.reduction`` is not a supported mode.
        """
        gt_lt, gt_rb = gt_bboxes[:, :2], gt_bboxes[:, 2:]
        pr_lt, pr_rb = pr_bboxes[:, :2], pr_bboxes[:, 2:]

        # Box areas use the same +to_remove offset and clamp as the
        # intersection/enclosure. Mixing conventions breaks IoU: identical
        # boxes would not score IoU == 1.
        gt_area = self._area(gt_lt, gt_rb)
        pr_area = self._area(pr_lt, pr_rb)

        # iou
        inter = self._area(torch.max(gt_lt, pr_lt), torch.min(gt_rb, pr_rb))
        union = gt_area + pr_area - inter
        iou = inter / union

        # smallest enclosing box
        enclosure = self._area(torch.min(gt_lt, pr_lt), torch.max(gt_rb, pr_rb))

        giou = iou - (enclosure - union) / enclosure
        loss = 1. - giou
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}")

if __name__ == '__main__':
    # Quick manual check: one ground-truth box and one overlapping prediction.
    criterion = GIOULoss()
    gt_bbox = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
    pr_bbox = torch.tensor([[2, 3, 4, 5]], dtype=torch.float32)
    print(criterion(gt_bbox, pr_bbox))